# ------------------------------------------------------------------------------
# Produces behavioral regression tables for the testing and yield outcomes
# Author: Cassidy Shubatt <cshubatt@gmail.com>
# To run: bash 06_behavioral_reg_tbl.sh
# ------------------------------------------------------------------------------

# Libraries --------------------------------------------------------------------
message("Loading libraries...")
library(here)
library(yaml)
library(tidyverse)
library(glue)
library(testit) # assert()
library(broom) # tidy()
library(data.table) # setnames()
library(lfe) # felm
library(xtable)
library(stargazer)

# Output directory for the .tex tables written at the end of the script.
# NOTE(review): this lives under 07_other_physician_errors while the header
# says to run 06_behavioral_reg_tbl.sh — confirm the mismatch is intended.
temp <- here("code", "07_other_physician_errors", "temp")
u <- modules::use(here("lib", "util.R"))

# Load Data --------------------------------------------------------------------
message("Loading data...")
# NOTE(review): not referenced directly in this script; presumably consumed as
# a `{overnight_lab}` placeholder by glue() in the filepaths.yml templates.
overnight_lab <- ""
paths <- read_yaml(here::here("lib", "filepaths.yml"))

# Test-set cohort: drop rows flagged `exclude`, then alias the full ensemble
# score as yhat_full (the baseline regressor below).
cohort <- readRDS(glue(paths$analysis$test_cohort)) %>%
  filter(!exclude) %>%
  mutate(yhat_full = p__ensemble__stent_or_cabg_010_day__tested)

# New Subcat yhats -------------------------------------------------------------
# Fit a linear combiner of the lab and vital-sign subcategory scores. It uses
# full-cohort rows whose `split` is missing (presumably the validation set,
# matching the scores_val_set.rds files read here — confirm). The script later
# applies this fit to the test cohort to produce yhat_lvs_lab.
val_cohort <- readRDS(glue(paths$analysis$full_cohort)) %>%
  filter(is.na(split))

lab_val <- file.path(
  paths$modeling$dir, "prediction", "random", "lab", "scores_val_set.rds"
) %>%
  readRDS() %>%
  setnames("p__ensemble__stent_or_cabg_010_day__tested__lab", "yhat_lab")

lvs_val <- file.path(
  paths$modeling$dir, "prediction", "random", "lvs", "scores_val_set.rds"
) %>%
  readRDS() %>%
  setnames("p__ensemble__stent_or_cabg_010_day__tested__lvs", "yhat_lvs")

val_cohort <- val_cohort %>%
  u$safe_left_join(lab_val) %>%
  u$safe_left_join(lvs_val)

# OLS of stent_or_cabg_010_day on the two component scores, using only rows
# where test_010_day is TRUE.
lab_lvs_fit <- val_cohort %>%
  filter(test_010_day) %>%
  lm(stent_or_cabg_010_day ~ yhat_lab + yhat_lvs, data = .)

# Get subcategory predictions --------------------------------------------------
# Left-join each subcategory model's test-set score onto `cohort`, one
# subcategory at a time, renaming its ensemble column to yhat_<subcat>.
message("Getting subcategory yhats...")
subcats <- c("justcc", "represent", "dia", "prc", "dem", "lab", "lvs")
cohort <- Reduce(
  function(acc, subcat) {
    predictions <- file.path(
      paths$modeling$dir, "prediction", "random", subcat, "scores_test_set.rds"
    ) %>%
      readRDS() %>%
      setnames(
        glue("p__ensemble__stent_or_cabg_010_day__tested__{subcat}"),
        glue("yhat_{subcat}")
      )
    acc %>% u$safe_left_join(predictions)
  },
  subcats,
  init = cohort
)

# Score the test cohort with the lab + vital-sign combiner, then replace the
# separate "lab" and "lvs" entries with the combined "lvs_lab" subcategory.
cohort <- cohort %>%
  mutate(yhat_lvs_lab = predict(lab_lvs_fit, newdata = .))
subcats <- setdiff(c(subcats, "lvs_lab"), c("lab", "lvs"))

# Representative/salience regressions ------------------------------------------
# For each outcome, fit four felm models of the outcome on the full-model
# predicted risk (yhat_full) plus different sets of subcategory scores, and
# save them side by side as one LaTeX table produced by stargazer.

# Tail of felm's multi-part formula `y ~ x | fe | iv | cluster`: no fixed
# effects (0), no instruments (0), standard errors clustered on ptid. The
# leading `1` is just the intercept term joined onto the regressors.
cluster_spec <- "1 | 0 | 0 | ptid"

# Regressors for the "all subcategories" model. Representative symptoms are
# excluded there and get their own model instead.
subcat_yhats <- setdiff(glue("yhat_{subcats}"), "yhat_represent")

# Coefficient labels, in the order they should appear in every table.
var_labs <- c(
  yhat_full = "Predicted Risk, Full", yhat_justcc = "Symptoms",
  yhat_represent = "Representative Symptoms",
  yhat_dem = "Demographics",
  yhat_dia = "Prior Diagnoses", yhat_prc = "Prior Procedures",
  yhat_lvs_lab = "Prior Labs and Vital Signs"
)
# stargazer interprets `order` entries as regular expressions; anchor them so
# one name cannot also match a longer one (e.g. "yhat_lvs" vs "yhat_lvs_lab").
var_order <- paste0("^", names(var_labs), "$")

for (outcome in c("test_010_day", "stent_or_cabg_010_day")) {
  message("Regressing ", outcome, " on subcomponent risk...")
  # The stent/CABG (yield) outcome is only modeled on rows where test_010_day
  # is TRUE. No in-place modification happens, so no defensive copy is needed.
  model_df <- cohort
  if (outcome == "stent_or_cabg_010_day") {
    model_df <- filter(model_df, test_010_day)
  }

  # (1) Full predicted risk only.
  form_baseline <- reformulate(
    response = outcome,
    termlabels = c("yhat_full", cluster_spec)
  )
  fit_baseline <- felm(form_baseline, model_df)

  # (2) Add the symptoms subcategory score.
  form_symp <- reformulate(
    response = outcome,
    termlabels = c("yhat_full", "yhat_justcc", cluster_spec)
  )
  fit_symp <- felm(form_symp, model_df)

  # (3) Add every subcategory score except representative symptoms.
  form_subcats <- reformulate(
    response = outcome,
    termlabels = c("yhat_full", subcat_yhats, cluster_spec)
  )
  fit_subcats <- felm(form_subcats, model_df)

  # (4) Symptoms plus representative symptoms.
  form_represent <- reformulate(
    response = outcome,
    termlabels = c("yhat_full", "yhat_justcc", "yhat_represent", cluster_spec)
  )
  fit_represent <- felm(form_represent, model_df)

  # stargazer prints the table and also returns it as a character vector of
  # LaTeX lines, which is what gets saved below.
  reg_tbl <- stargazer(
    fit_baseline, fit_symp, fit_subcats, fit_represent,
    omit.stat = c("f", "ser", "ll", "aic", "adj.rsq"),
    order = var_order, covariate.labels = var_labs
  )

  message("Saving salience regressions for ", outcome, "...")
  save_fp <- file.path(temp, glue("{outcome}__salience_regressions.tex"))
  # Overwrites any existing file; one table line per output line.
  writeLines(reg_tbl, save_fp)
}

message("Done.")
